{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "from Dataset import Trainset, Testset\n",
    "from train_valid_split import train_valid_split\n",
    "from scipy.io import loadmat\n",
    "from Model import Net"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn import datasets\n",
    "from sklearn.decomposition import PCA, FastICA\n",
    "from sklearn.naive_bayes import GaussianNB\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [],
   "source": [
    "pca = PCA(10, whiten=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "metadata": {},
   "outputs": [],
   "source": [
    "raw_data = loadmat('./sub_a.mat')\n",
    "\n",
    "signals = raw_data['responses']\n",
    "label = raw_data['is_stimulate']\n",
    "\n",
    "data = []\n",
    "target = []\n",
    "\n",
    "for i in range(12):\n",
    "    for j in range(85):\n",
    "        data.append(pca.fit_transform(signals[i, :, :, j].T).T)\n",
    "        target.append(label[i, j])\n",
    "\n",
    "data = np.array(data)\n",
    "target = np.array(target)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1020, 10, 64)"
      ]
     },
     "execution_count": 87,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [],
   "source": [
    "raw_data = loadmat('./sub_b_test.mat')\n",
    "\n",
    "signals = raw_data['responses']\n",
    "\n",
    "data = []\n",
    "\n",
    "for i in range(12):\n",
    "    for j in range(85):\n",
    "        data.append(signals[i, :, :, j].reshape(-1, 64))\n",
    "\n",
    "data = np.array(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "x_train, x_test, y_train, y_test = train_test_split(data,\n",
    "                                                   target,\n",
    "                                                    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.8303921568627451\n",
      "0.8333333333333334\n",
      "0.8352941176470589\n",
      "0.8333333333333334\n",
      "0.8303921568627451\n",
      "0.8333333333333334\n",
      "0.8333333333333334\n",
      "0.8333333333333334\n",
      "0.8333333333333334\n",
      "0.8333333333333334\n",
      "0.8323529411764706\n",
      "0.8303921568627451\n",
      "0.8333333333333334\n",
      "0.8303921568627451\n",
      "0.8333333333333334\n",
      "0.8333333333333334\n",
      "0.8313725490196079\n",
      "0.8333333333333334\n",
      "0.8323529411764706\n",
      "0.8362745098039216\n",
      "0.8333333333333334\n",
      "0.8323529411764706\n",
      "0.8333333333333334\n",
      "0.8323529411764706\n",
      "0.8333333333333334\n",
      "0.8333333333333334\n",
      "0.8333333333333334\n",
      "0.8352941176470589\n",
      "0.8333333333333334\n",
      "0.8333333333333334\n",
      "0.8313725490196079\n",
      "0.8333333333333334\n",
      "0.8333333333333334\n",
      "0.8343137254901961\n",
      "0.8343137254901961\n",
      "0.8333333333333334\n",
      "0.8343137254901961\n",
      "0.8333333333333334\n",
      "0.8333333333333334\n",
      "0.8323529411764706\n",
      "0.8333333333333334\n",
      "0.8323529411764706\n",
      "0.8333333333333334\n",
      "0.8343137254901961\n",
      "0.8352941176470589\n",
      "0.8333333333333334\n",
      "0.8294117647058824\n",
      "0.8333333333333334\n",
      "0.8333333333333334\n",
      "0.8333333333333334\n",
      "0.8323529411764706\n",
      "0.8333333333333334\n",
      "0.8382352941176471\n",
      "0.8333333333333334\n",
      "0.8333333333333334\n",
      "0.8323529411764706\n",
      "0.8333333333333334\n",
      "0.8333333333333334\n",
      "0.8323529411764706\n",
      "0.8343137254901961\n",
      "0.8333333333333334\n",
      "0.8343137254901961\n",
      "0.8333333333333334\n",
      "0.8343137254901961\n"
     ]
    }
   ],
   "source": [
    "classifiers = []\n",
    "\n",
    "for i in range(64):\n",
    "    gnb = GaussianNB()\n",
    "    bayes = gnb.fit(data[:,:,i], target)\n",
    "    print(gnb.score(data[:,:,i], target))\n",
    "    classifiers.append(bayes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 89,
   "metadata": {},
   "outputs": [],
   "source": [
    "vote = np.zeros(data.shape[0])\n",
    "for step, bayes in enumerate(classifiers):\n",
    "    vote = vote + bayes.predict(data[:, :, step])\n",
    "    \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 90,
   "metadata": {},
   "outputs": [],
   "source": [
    "char = [['A', 'B', 'C', 'D', 'E', 'F'],\n",
    "        ['G', 'H', 'I', 'J', 'K', 'L'],\n",
    "        ['M', 'N', 'O', 'P', 'Q', 'R'],\n",
    "        ['S', 'T', 'U', 'V', 'W', 'X'],\n",
    "        ['Y', 'Z', '1', '2', '3', '4'],\n",
    "        ['5', '6', '7', '8', '9', '_']]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 91,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
      "[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
      "[0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]\n",
      "[0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]\n",
      "[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
      "[1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
      "[1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n",
      "[0. 0. 0. 0. 0. 0. 1. 0. 1. 0. 0. 0.]\n",
      "[0. 0. 0. 0. 0. 0. 0. 1. 1. 0. 0. 1.]\n",
      "[0. 0. 0. 0. 0. 0. 2. 0. 0. 0. 0. 0.]\n",
      "[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
      "[0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]\n",
      "[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
      "[0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 1. 1.]\n",
      "[1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
      "[0. 0. 0. 1. 0. 0. 1. 1. 2. 0. 0. 1.]\n",
      "[0. 0. 0. 1. 0. 1. 0. 0. 0. 0. 0. 0.]\n",
      "[0. 1. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0.]\n",
      "[0. 0. 0. 0. 0. 0. 0. 0. 0. 2. 1. 0.]\n",
      "[1. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]\n",
      "[0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]\n",
      "[0. 1. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]\n",
      "[0. 2. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
      "[0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]\n",
      "[0. 0. 0. 0. 0. 1. 0. 0. 0. 1. 0. 0.]\n",
      "[0. 1. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]\n",
      "[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
      "[0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]\n",
      "[0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]\n",
      "[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]\n",
      "[0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]\n",
      "[0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 1. 1.]\n",
      "[0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]\n",
      "[0. 0. 1. 1. 0. 0. 3. 0. 0. 0. 0. 0.]\n",
      "[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
      "[1. 0. 0. 2. 0. 0. 0. 0. 0. 1. 0. 0.]\n",
      "[0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
      "[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
      "[0. 0. 0. 0. 0. 2. 0. 1. 0. 0. 0. 0.]\n",
      "[0. 0. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0.]\n",
      "[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
      "[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
      "[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
      "[0. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n",
      "[0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]\n",
      "[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]\n",
      "[0. 0. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0.]\n",
      "[0. 2. 0. 0. 1. 0. 0. 1. 3. 0. 0. 0.]\n",
      "[0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
      "[0. 0. 0. 0. 0. 1. 0. 0. 1. 0. 0. 0.]\n",
      "[0. 0. 0. 2. 2. 0. 0. 0. 0. 1. 1. 0.]\n",
      "[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n",
      "[1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
      "[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
      "[1. 0. 1. 0. 0. 0. 0. 0. 0. 1. 0. 0.]\n",
      "[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
      "[0. 0. 0. 0. 1. 2. 0. 1. 0. 0. 0. 0.]\n",
      "[0. 0. 0. 1. 0. 0. 0. 0. 1. 0. 0. 0.]\n",
      "[0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
      "[0. 0. 1. 0. 0. 0. 0. 1. 0. 1. 0. 0.]\n",
      "[1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]\n",
      "[2. 2. 0. 0. 0. 2. 1. 0. 0. 0. 0. 0.]\n",
      "[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
      "[0. 0. 0. 1. 0. 0. 0. 1. 0. 0. 0. 0.]\n",
      "[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
      "[0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]\n",
      "[0. 0. 1. 0. 0. 0. 0. 0. 0. 2. 0. 0.]\n",
      "[0. 0. 0. 1. 0. 0. 0. 1. 0. 0. 0. 0.]\n",
      "[0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]\n",
      "[0. 0. 0. 0. 0. 0. 1. 0. 2. 1. 0. 0.]\n",
      "[0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
      "[0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]\n",
      "[0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]\n",
      "[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
      "[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
      "[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
      "[1. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]\n",
      "[0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n",
      "[0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]\n",
      "[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
      "[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
      "[0. 0. 0. 0. 0. 0. 0. 1. 0. 1. 0. 0.]\n",
      "[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n",
      "[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
      "[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n"
     ]
    }
   ],
   "source": [
    "series = []\n",
    "for i in range(0, data.shape[0], 12):\n",
    "    print(vote[i:i+12])\n",
    "    series.append(char[np.argmax(vote[i+6:i+12])][np.argmax(vote[i:i+6])])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 92,
   "metadata": {},
   "outputs": [],
   "source": [
    "str = ''.join(series)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 93,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'AAEAAA5AGAAMAZAPDNSAZBBEXBAGAYS1ECAVCALGAAA6EYGNBRV5AASALPBIYAAJAMUJSMDABAAAA6BAAG5AA'"
      ]
     },
     "execution_count": 93,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "str"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 94,
   "metadata": {},
   "outputs": [],
   "source": [
    "real_b = 'MERMIROOMUHJPXJOHUVLEORZP3GLOO7AUFDKEFTWEOOALZOP9ROCGZET1Y19EWX65QUYU7NAK_4YCJDVDNGQXODBEV2B5EFDIDNR'\n",
    "real_b_train = 'VGREAAH8TVRHBYN_UGCOLO4EUERDOOHCIFOMDNU6LQCPKEIREKOYRQIDJXPBKOJDWZEUEWWFOEBHXTQTTZUMO'\n",
    "\n",
    "real_a_train = 'EAEVQTDOJG8RBRGONCEDHCTUIDBPUHMEM6OUXOCFOUKWA4VJEFRZROLHYNQDW_EKTLBWXEPOUIKZERYOOTHQI'\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 95,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2"
      ]
     },
     "execution_count": 95,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "counter = 0\n",
    "for i in range(85):\n",
    "    if str[i] == real_a_train[i]:\n",
    "        counter += 1\n",
    "\n",
    "counter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:torch]",
   "language": "python",
   "name": "conda-env-torch-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
